adding variable length attention to llama3 8b #2000
Conversation
Force-pushed eeecb63 to cad97e5
fegin left a comment:
This implementation won't work with PP and is too intrusive to the model code. The packing logic should be hidden inside the inner attention.
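(For illustration, a minimal sketch of what keeping the packing logic inside the inner attention could look like; the class name, the `seq_lens` argument, and the per-document loop are assumptions, not torchtitan's actual implementation, which would call a fused varlen kernel instead of looping.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InnerAttention(nn.Module):
    """Sketch: packing details stay inside the attention module, so the rest
    of the model (and PP stage splitting) only ever sees plain (q, k, v)."""

    def forward(self, q, k, v, seq_lens=None):
        # q, k, v: (batch, n_heads, seqlen, head_dim)
        if seq_lens is None:
            # Dense path: ordinary causal SDPA.
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # "Varlen" path (illustrative only): attend within each packed
        # document independently; a real kernel would take cu_seqlens
        # instead of looping in Python.
        out = torch.empty_like(q)
        start = 0
        for length in seq_lens:
            end = start + length
            out[:, :, start:end] = F.scaled_dot_product_attention(
                q[:, :, start:end], k[:, :, start:end], v[:, :, start:end],
                is_causal=True,
            )
            start = end
        return out


# The caller never packs or unpacks anything itself:
q = k = v = torch.randn(1, 8, 10, 64)
out = InnerAttention()(q, k, v, seq_lens=[4, 6])  # two docs packed into one sequence
```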
Force-pushed 55352a5 to 066ca02
Force-pushed 066ca02 to c9b6d5c
fegin left a comment:
LGTM, thanks for the update. Left some other comments; once they are addressed, this PR should be ready.
Force-pushed a902cbe to de416f9
tianyu-l left a comment:
Thanks! Left some comments, please see if they make sense to you.
Force-pushed caafc81 to 4d36560
Force-pushed 9380847 to 42c0c85
tianyu-l left a comment:
Left some more comments. If you'd like to focus on Llama 3 in this PR, that's fine with me too.
Force-pushed 5528029 to 31c1c77
fegin left a comment:
LGTM, we can leave other models to other PR(s).
Force-pushed b717da3 to 9c99fcb
Reviewed lines:

```
    xv,
    self.head_dim,
    attention_masks,
    is_causal=True,
```
This would fail? I think `is_causal` is no longer accepted.
Btw, it seems varlen is not tested in CI; can we add a test similar to https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/features.py#L336?
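(A varlen entry could look roughly like the sketch below, mirroring the flex-attention test linked above. This is a fragment meant to live in that test list; `OverrideDefinitions` is the helper already used there, and the flavor name, descriptions, and field layout are assumptions to verify against the current file.)

```python
# Hypothetical addition to tests/integration_tests/features.py; the
# "debugmodel_varlen_attn" flavor and the field values are assumptions.
OverrideDefinitions(
    [
        [
            "--model.flavor debugmodel_varlen_attn",
        ],
    ],
    "Variable length attention",
    "varlen_attn",
    ngpu=4,
)
```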
Force-pushed 1af38e5 to df22636
Force-pushed df22636 to 2b1a40f
tianyu-l left a comment:
LGTM.
We need to modify the `save_list` of SAC to save the result of varlen attention, to be consistent with the other attention implementations. This can be done in the next PR.
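(Sketch of that follow-up, assuming SAC's policy is driven by an op set like torchtitan's `_save_list`; which ATen op the varlen path actually dispatches to is an assumption here.)

```python
import torch

# Existing-style save list (abridged) plus the varlen attention op, so its
# output is saved rather than recomputed under per-op SAC. Whether varlen
# shows up as aten._flash_attention_forward is an assumption to verify.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._flash_attention_forward.default,  # new: varlen attention
}
```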
Reviewed lines:

```
[
    [
        "--parallelism.data_parallel_shard_degree=4",
        "--activation_checkpoint.mode='full'",
```
Let's use `per_op_sac` like the test above.
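(The suggested override would look roughly like this; the flag spellings follow torchtitan's activation-checkpoint config but should be checked against the current job config.)

```python
[
    [
        "--parallelism.data_parallel_shard_degree=4",
        "--activation_checkpoint.mode=selective",
        "--activation_checkpoint.selective_ac_option=op",
    ],
]
```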
This reverts commit f8fa21e.
Summary
This PR adds variable length attention (varlen) support to the Llama 3 8B model in torchtitan. We replace `use_flex_attn` with `attn_type` (one of "sdpa", "varlen", or "flex"). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined here.
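(A condensed sketch of the dispatch this describes; `build_attention` and the `varlen_attn` signature below are placeholders, not the PR's exact code.)

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention


def varlen_attn(q, k, v, cu_seqlens, max_seqlen):
    """Placeholder for the compiled varlen kernel this PR adds; the real one
    wraps flash_attention_forward / flash_attention_backward."""
    raise NotImplementedError


def build_attention(attn_type: str):
    # attn_type replaces the old use_flex_attn flag: "sdpa" | "flex" | "varlen".
    if attn_type == "sdpa":
        return lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True)
    if attn_type == "flex":
        return torch.compile(flex_attention)
    if attn_type == "varlen":
        return torch.compile(varlen_attn)
    raise ValueError(f"unknown attn_type: {attn_type}")
```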
Testing
Ran loss and performance tests against flex attention; loss is on par. Varlen is slightly slower than flex due to CUDA kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today).